#!/bin/env python
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

#----------
# Read Data
#----------
def _read_ensemble(path):
    """
    Load one model's JSON table and summarise it across its columns.

    The table index is used as the year axis and the statistics are taken
    along axis=1, i.e. across the columns (treated here as ensemble members).

    Returns (frame, years, values, ensemble mean, ensemble max, ensemble min).
    """
    frame = pd.read_json(path)
    values = frame.to_numpy()
    return (frame, frame.index, values,
            values.mean(axis=1), values.max(axis=1), values.min(axis=1))


#---Observed (a single series indexed by year)
datafl1 = "./Data/anomalousdays_observations.json"
alldata1 = pd.read_json(datafl1, typ='series')
years, obsdata = alldata1.index, alldata1.to_numpy()

#---HiFLOR AllForc, 1971-2050
datafl2 = "./Data/anomalousdays_HiFLOR_AllForc.json"
alldata2, years2a, alldata2a, ensmean2a, ensmax2a, ensmin2a = _read_ensemble(datafl2)

#---HiFLOR NatForc, 1971-2020
datafl3 = "./Data/anomalousdays_HiFLOR_NatForc.json"
alldata3, years3a, alldata3a, ensmean3a, ensmax3a, ensmin3a = _read_ensemble(datafl3)

#---SPEAR AllForc, 1971-2100
datafl4 = "./Data/anomalousdays_SPEAR_AllForc.json"
alldata4, years4a, alldata4a, ensmean4a, ensmax4a, ensmin4a = _read_ensemble(datafl4)

#---SPEAR NatForc, 1971-2100
datafl5 = "./Data/anomalousdays_SPEAR_NatForc.json"
alldata5, years5a, alldata5a, ensmean5a, ensmax5a, ensmin5a = _read_ensemble(datafl5)

#---------
# Subroutines
#---------
def linreg(X,Y):
    """
    Ordinary least-squares fit of the straight line y = a*x + b.

    Parameters
    ----------
    X, Y : 1-D array-like of numbers with the same length; X must contain
           at least two distinct values.

    Returns
    -------
    X   : the input X object, returned unchanged
    YY2 : numpy array of fitted values a*X + b
    aa  : the slope a

    Raises
    ------
    ValueError : if X and Y differ in shape, or X has fewer than two
                 distinct values (the slope is then undefined).
    """
    x = np.asarray(X, dtype=float)
    y = np.asarray(Y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("linreg: X and Y must have the same shape, got %s and %s"
                         % (x.shape, y.shape))
    # A vertical/degenerate X gives a zero denominator; fail loudly instead of
    # returning inf/NaN with only a runtime warning.
    if x.size < 2 or np.ptp(x) == 0:
        raise ValueError("linreg: X needs at least two distinct values")

    # Centre the data before forming the sums: with X being calendar years
    # (~2000) the raw-sums formula N*Sxx - Sx**2 subtracts two huge, nearly
    # equal numbers; the centred form is algebraically identical but stable.
    xm = x.mean()
    ym = y.mean()
    dx = x - xm
    aa = np.sum(dx * (y - ym)) / np.sum(dx * dx)
    bb = ym - aa * xm

    # Use the array form of X so this also works when X is a plain list.
    YY2 = aa * x + bb
    return X, YY2, aa

def mk_test(x, alpha = 0.05):
    """
    Perform the Mann-Kendall (MK) test for a monotonic trend in a series.

    Input:
        x:     a 1-D vector of data (list or array of numbers)
        alpha: significance level of the two-sided test

    Output:
        h: True if a trend is present at level alpha, otherwise False
        p: p value of the significance test (NaN if x contains NaN)

    Examples
    --------
      >>> x = np.random.rand(100)
      >>> h,p = mk_test(x,0.05)
    """
    # Imported here (after the docstring, so the docstring is a real one).
    from scipy.stats import norm

    x = np.asarray(x, dtype=float)
    n = len(x)

    # S = sum over all pairs k < j of sign(x[j] - x[k]).
    k_idx, j_idx = np.triu_indices(n, k=1)
    s = np.sign(x[j_idx] - x[k_idx]).sum()

    # Var(S) with the tie correction: every group of t tied values REDUCES the
    # variance by t*(t-1)*(2t+5). Untied values (t=1) contribute 0, so the
    # same expression covers the no-tie case.
    _, tp = np.unique(x, return_counts=True)
    var_s = (n*(n-1)*(2*n+5) - np.sum(tp*(tp-1)*(2*tp+5)))/18

    # Continuity-corrected standard normal statistic.
    if s > 0:
        z = (s - 1)/np.sqrt(var_s)
    elif s < 0:
        z = (s + 1)/np.sqrt(var_s)
    elif s == 0:
        z = 0.0
    else:  # s is NaN, which only happens when x contains NaN
        z = np.nan

    # Two-tailed p-value; sf() avoids the precision loss of 1 - cdf() in the tail.
    p = 2*norm.sf(abs(z))
    h = abs(z) > norm.ppf(1-alpha/2)

    return h, p

def cal_trend_pval(syear,eyear,years,data, factor=1.0):
    """
    Linear trend and Mann-Kendall p-value of one series over a year window.

    Parameters
    ----------
    syear, eyear : first and last year of the window (both inclusive)
    years        : 1-D array-like of years, aligned element-wise with data
    data         : 1-D array-like of values, one per entry of years
    factor       : multiplier applied to the fitted slope

    Returns
    -------
    (slope * factor, p-value) for the points whose year lies in [syear, eyear].
    """
    # Work on plain arrays so the selection below is positional, whether the
    # caller passes lists, numpy arrays, pandas Index or Series.
    years = np.asarray(years)
    data = np.asarray(data)

    in_window = (years >= syear) & (years <= eyear)
    window_years = years[in_window]
    window_data = data[in_window]

    # Only the p-value is needed; callers apply their own significance threshold.
    _, pval1 = mk_test(window_data, alpha=0.05)
    _, _, slope = linreg(window_years, window_data)
    return slope*factor, pval1

def cal_trends(syear,eyear,years,data,factor=1.0):
    """
    Trend and p-value for every column of a 2-D array.

    Each column of data is passed to cal_trend_pval over [syear, eyear];
    the resulting slope is then scaled by factor.

    Returns
    -------
    (trends, pvals) : two 1-D float arrays, one entry per column of data.
    """
    # Unpacking also enforces that data is 2-D.
    _, n_cols = np.shape(data)
    trends = np.zeros(n_cols)
    pvals = np.zeros(n_cols)
    for col in range(n_cols):
        slope, p_value = cal_trend_pval(syear, eyear, years, data[:, col])
        pvals[col] = p_value
        trends[col] = slope
        trends[col] = trends[col] * factor
    return trends, pvals

#----------
# Compute trends
#----------
# trend[key]      : trend of the observed series or of the ensemble mean
# pval[key]       : Mann-Kendall p-value for that trend
# trend_data[key] : per-member trends (what the box/strip plots show)
trend = {}
pval = {}
trend_data = {}

#--Observed: a single series, so its "member" list is just its own trend
trend['Obs'], pval['Obs'] = cal_trend_pval(1977,2015,years,obsdata)
trend_data['Obs'] = [trend['Obs']]

# Blank spacer entry for the plot; it deliberately has no p-value.
trend['dummy'] = np.nan
trend_data['dummy'] = [np.nan]

# (key, first year, last year, year axis, ensemble mean, all members)
# HiFLOR NatForc only covers 1971-2020, so it has no 1977-2050 entry.
model_windows = [
    ('HiFLOR_all_p', 1977, 2015, years2a, ensmean2a, alldata2a),
    ('HiFLOR_all_f', 1977, 2050, years2a, ensmean2a, alldata2a),
    ('HiFLOR_nat_p', 1977, 2015, years3a, ensmean3a, alldata3a),
    ('SPEAR_all_p',  1977, 2015, years4a, ensmean4a, alldata4a),
    ('SPEAR_all_f',  1977, 2050, years4a, ensmean4a, alldata4a),
    ('SPEAR_nat_p',  1977, 2015, years5a, ensmean5a, alldata5a),
    ('SPEAR_nat_f',  1977, 2050, years5a, ensmean5a, alldata5a),
]
for key, first_yr, last_yr, yr_axis, mean_series, members in model_windows:
    trend[key], pval[key] = cal_trend_pval(first_yr, last_yr, yr_axis, mean_series)
    # per-member p-values are not used further
    trend_data[key], dummy = cal_trends(first_yr, last_yr, yr_axis, members)


#----------
# Build plotting tables
#----------
# Experiments in display order. 'dummy' is a blank spacer row separating the
# 1977-2015 entries from the 1977-2050 entries. labs (below) must list the
# tick labels in exactly this order.
exp_order = ["Obs","HiFLOR_all_p","SPEAR_all_p","HiFLOR_nat_p","SPEAR_nat_p",'dummy',"HiFLOR_all_f","SPEAR_all_f","SPEAR_nat_f"]


def _trend_color(exp):
    """
    Colour for experiment exp.

    Returns white for the spacer row. Returns red or blue when the p-value is
    <= 0.05 and the trend is positive or negative, respectively. Returns gray
    otherwise.
    """
    if exp == "dummy":
        return "white"
    if pval[exp] <= 0.05 and trend[exp] > 0:
        return "red"
    if pval[exp] <= 0.05 and trend[exp] < 0:
        return "blue"
    return "gray"


columns = ["Experiment", "Trends"]
colors = [_trend_color(exp) for exp in exp_order]
ccolors = dict(zip(exp_order, colors))

# One row per ensemble member. Obs and dummy rows carry NaN so they keep their
# category slot without drawing a box or points. Building from records (rather
# than filling an empty frame row by row) gives a float64 "Trends" column.
df = pd.DataFrame(
    [(exp, np.nan if exp in ("Obs", "dummy") else tt)
     for exp in exp_order
     for tt in trend_data[exp]],
    columns=columns,
)
# One row per experiment: its observed / ensemble-mean trend.
dfmean = pd.DataFrame([(exp, trend[exp]) for exp in exp_order], columns=columns)

labs = ["Observations (1977\u20142015)","HiFLOR (AllForc, 1977\u20142015)", "SPEAR (AllForc, 1977\u20142015)",'HiFLOR (NatForc, 1977\u20142015)',"SPEAR (NatForc, 1977\u20142015)",
        "", "HiFLOR (AllForc, 1977\u20142050)", "SPEAR (AllForc, 1977\u20142050)","SPEAR (NatForc, 1977\u20142050)"]

#----------
# Plot
#----------
fig, ax = plt.subplots(figsize=(8.0,5.0))

# All three layers share the same mapping: experiments on y, trends on x,
# drawn onto the axes created above.
axis_map = dict(y="Experiment", x="Trends", ax=ax)

# Layer 1: spread of member trends; box colour marks significance/sign.
# NOTE(review): seaborn >= 0.13 warns when `palette` is given without `hue` — confirm
# against the installed seaborn version.
ax = sns.boxplot(data=df, whis=1.5, color="c", palette=ccolors,
                 boxprops=dict(alpha=0.3), **axis_map)
# Layer 2: individual ensemble members as small dark dots.
sns.stripplot(data=df, jitter=True, size=4, color=".3", linewidth=0, **axis_map)
# Layer 3: observed / ensemble-mean trend as a large square.
sns.stripplot(data=dfmean, marker="s", jitter=False, size=14, palette=colors,
              linewidth=0, **axis_map)

ax.set_yticklabels(labs)
ax.set_title('(b) Linear Trend in Anomalous Days', fontsize=16)
ax.set_xlabel('Days per year', fontsize=14)
ax.axvline(x=0, linestyle='dashed', color="gray")  # zero-trend reference line
ax.set_ylabel("")

for tick_setter in (plt.yticks, plt.xticks):
    tick_setter(fontsize=14)

plt.tight_layout()
plt.savefig("./Fig4b.png")
